{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Digit Classification\n",
    "In this tutorial, we will use graph kernels to classify hand-written digits. Each datapoint is a 8x8 image of a digit. There are 10 classes in total, and each class refers to a digit. The dataset is contained in scikit-learn and we can load it using the *load_digits()* function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAABnCAYAAACjHpHIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAI8ElEQVR4nO3db4xUVx3G8ecRrI3hz0K0L1rbLNgXNUaXAGnSaCxESDBV2UaLiW0iGAuJbySaBl7UBrSJkFQFTTRb/xFTNYAvICUxFQygbWwt6JJYjRpgg0hL0sJS2pIq8vPFHeykKXvP7Nw5M3P3+0ma7LC/uefsb7vP3L1zzx5HhAAAebyt2xMAgKmE0AWAjAhdAMiI0AWAjAhdAMiI0AWAjLoauran2X7F9i1V1oLedhr97Zy697al0G18cVf/u2L7UtPje1sdPCL+GxEzIuJUlbVVsP2A7RdsX7D9Q9vXdXi8KdFb20O2f237JduXOz1e07hTpb+ft/1H2y/bPm37G7andXjMqdLbe23/rZEJZ23/xPaMlo8z2cURtsckfSEiDkxQMz0isv1gVcX2XZJ+JGmppLOS9ko6HBEPZhp/TPXt7fsk3SFpXNKuiJjehTmMqb79/aKkY5KelXSDpH2SHouIRzKNP6b69vYWSa9FxIu2Z0r6gaQzEfHlVo5T6eUF2w/b3mn7F7YvSrrP9h22n7Y9bvt529+x/fZG/XTbYXuw8fixxud/Zfui7d/bntdqbePzH7P998ar0ndtP2V7deKX8jlJj0bEXyPinKSHJaU+tyPq0ttGT38s6S8VtqdtNerv9yLiqYj4d0SclvRzSR+qrlOtq1FvT0XEi03/dEXSra32oxPXdO9W8Y2eLWmnpMuSviTpXSq++SskrZvg+Z+V9FVJcyWdkvT1Vmtt3yBpl6QHGuOelHT71SfZntf4Zt94jeO+X8XZwlXHJN1ke/YEc8mhDr3tZXXs70ckPZdY20m16K3tO21fkPSypE9K2jbBPN5SJ0L3yYh4PCKuRMSliHg2Ip6JiMsRcULSo5LunOD5v4yIIxHxH0k/k7RgErUflzQaEXsbn/u2pP+/QkXEyYgYiIgz1zjuDEkXmh5f/XjmBHPJoQ697WW16q/t+yV9UNK3ymozqEVvI+JwRMyWdLOkR1SEeks6cT3tn80PbN8m6ZuSFkl6Z2PMZyZ4/gtNH7+mIgBbrb2xeR4REbZPl878Da9ImtX0eFbTv3dTHXrby2rTX9ufUnGG99HGJbJuq01vG889bfuAirP328vqm3XiTPfN78yNSPqzpFsjYpakhyS5A+M2e17Se64+sG1JN7Xw/OckDTU9HpL0r4gYr2Z6k1aH3vayWvTXxRvB35d0V0T0wqUFqSa9fZPpkt7b6pNy3Kc7U8Wv56+6eOd6ous2VdknaaHtT9ieruLa0btbeP5PJd1v+zbbcyU9KGlH9dNsW9/11oXrJV3XeHy9O3w7Xhv6sb/LVfz/e3dEHO3QHKvQj729z/bNjY8HVfwm8ZtWJ5EjdL+i4m6Aiype3XZ2esCIOCvpMyquZb2k4tXoT5JelyTb813cQ/iWF8wjYp+K6z2/lTQm6R+SvtbpeU9C3/W2UX9JxZuT0xof99SdDE36sb8PqXiz6gm/ca/s452e9yT0Y28/IOlp269KelLFb8Qtv1hM+j7dfuLi5vAzkj4dEb/r9nzqhN52Fv3tnG71trZ/e8H2Ctuzbb9Dxe0jlyX9ocvTqgV621n0t3N6obe1DV1JH5Z0QsUtISskDUfE692dUm3Q286iv53T9d5OicsLANAr6nymCwA9h9AFgIzKVqRVcu1h9+7dpTUbNmworVm+fHnSeFu2bCmtmTNnTtKxErRzQ3e2aztLliwprRkfT1v7sXnz5tKalStXJh0rwWT7m623hw4dKq0ZHh5OOtaCBROtbk0fL1FXe7t169bSmo0bN5bWzJs3r7RGko4eLb9tOUcucKYLABkRugCQEaELABkRugCQEaELABkRugCQEaELABkRugCQUZbtr1MWPpw8ebK05vz580njzZ07t7Rm165dpTX33HNP0nj9YGBgoLTm8OHDScc6ePBgaU2FiyO6anR0tLRm6dKlpTWzZ6ftaTo2NpZU1+tSFjWk/AyOjIyU1qxbl/YnbVMWRyxbtizpWO3gTBcAMiJ0ASAjQhcAMiJ0ASAjQhcAMiJ0ASAjQhcAMiJ0ASCjthdHpNxwnLLw4fjx46U18+fPT5pTyg4TKfPul8URKTfwV7jbQNLuBnWxZ8+e0pqhoaHSmtSdI1J25egHa9euLa1JWTS1aNGi0prUnSNyLHxIwZkuAGRE6AJARoQuAGRE6AJARoQuAGRE6AJARoQuAGRE6AJARm0vjkjZzWHhwoWlNakLH1Kk3FDdL7Zt21Zas2nTptKaCxcuVDCbwpIlSyo7Vq9bv359ac3g4GAlx5Hqs+NGys/ziRMnSmtSFlalLnpIyao5c+YkHasdnOkCQEaELgBkROgCQEaELgBkROgCQEaELgBkROgCQEaELgBklGVxRMpODlXqlZugq5ByU/3q1atLa6r8esfHxys7VjelfB0pi1NSdpdItWPHjsqO1etSFlCcO3eutCZ1cURK3YEDB0pr2v1Z4kwXADIidAEgI0IXADIidAEgI0IXADIidAEgI0IXADIidAEgI0IXADJqe0VayuqMo0ePtjuMpLSVZpJ05MiR0ppVq1a1O50pa3R0tLRmwYIFGWbSnpRtjrZv317JWKmr1gYGBioZry5S8iVlFZkkrVu3rrRm69atpTVbtmxJGu9aONMFgIwIXQDIiNAFgIwIXQDIiNAFgIwIXQDIiNAFgIwIXQDIqO3FESlbbqQsVti9e3clNak2bNhQ2bHQn1K2OTp06FBpzbFjx0prhoeHE2YkrVy5srRmzZo1lRyn2zZu3Fhak7LFTuqiqf3795fW5Fg0xZkuAGRE6AJARoQuAGRE6AJARoQuAGRE6AJARoQuAGRE6AJARlkWR6T8NfaUxQqLFy9OmlNVO1X0i5TdBlJult+7d2/SeCkLBlIWHnRbyu4WKbtkpNSk7FIhpX0PBgcHS2v6YXFEyq4Qa9eurWy8lIUPIyMjlY13LZzpAkBGhC4AZEToAkBGhC4AZEToAkBGhC4AZEToAkBGhC4AZOSI6PYcAGDK4EwXADIidAEgI0IXADIidAEgI0IXADIidAEgo/8BRc5/fWgGnPgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 4 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "from sklearn.datasets import load_digits\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "digits = load_digits()\n",
    "\n",
    "images = digits.images\n",
    "n_samples, height, width = images.shape\n",
    "y = digits.target\n",
    "for i in range(4):\n",
    "    plt.subplot(1, 4, i + 1)\n",
    "    plt.axis('off')\n",
    "    plt.imshow(images[i], cmap=plt.cm.gray_r, interpolation='nearest')\n",
    "    plt.title('Training: %i' % y[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will now transform the digits into graphs and perform graph classification. We will transform each image into a square grid graph. Each pixel corresponds to a node. Two nodes are connected by an edge if they are direct neighbors. For instance, pixel (0,0) is connected to pixel (1,0) and to pixel (0,1). Note also that the node attributes (i.e., pixel intensities) are transformed into discrete labels, and that the edges are also assigned dicrete labels (i.e., labels 1 and 2 for the horizontal and vertical edges, respectively)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Transforming images to graphs\n"
     ]
    }
   ],
   "source": [
    "from grakel import Graph\n",
    "\n",
    "print(\"Transforming images to graphs\")\n",
    "graphs = list()\n",
    "edges = list()\n",
    "edge_labels = dict()\n",
    "for i in range(height):\n",
    "    for j in range(width):\n",
    "        if j < width-1:\n",
    "            edges.append((i*height+j, i*height+j+1))\n",
    "            edge_labels[(i*height+j, i*height+j+1)] = 1\n",
    "        if i < height-1:\n",
    "            edges.append((i*height+j, (i+1)*height+j))\n",
    "            edge_labels[(i*height+j, (i+1)*height+j)] = 2\n",
    "\n",
    "for i in range(n_samples):\n",
    "    node_labels = dict()\n",
    "    for j in range(height):\n",
    "        for k in range(width):\n",
    "            node_labels[j*height+k] = int(images[i,j,k]/4)\n",
    "    \n",
    "    graphs.append(Graph(edges, node_labels=node_labels, edge_labels=edge_labels))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We next split the dataset into a training and a test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Splitting dataset into train/test (1000/100 instances)\n"
     ]
    }
   ],
   "source": [
    "print(\"Splitting dataset into train/test (1000/100 instances)\")\n",
    "graphs_train, graphs_test = graphs[:1000], graphs[1000:1100]\n",
    "y_train, y_test = y[:1000], y[1000:1100]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We finally use the *neighborhood subgraph pairwise distance* kernel to generate the two kernel matrices and classify the digits. Note that this kernel can handle both node labels and edge labels. To perform classification, we pass the two matrices on to the SVM classifier, and we compute the accuracy of the predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing kernel matrics\n",
      "\n",
      "done in 33.948s\n",
      "\n",
      "Classifying digits\n",
      "\n",
      "Classification accuracy: 0.90\n"
     ]
    }
   ],
   "source": [
    "from time import time\n",
    "from grakel.kernels import NeighborhoodSubgraphPairwiseDistance\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "# Initialize neighborhood subgraph pairwise distance kernel\n",
    "gk = NeighborhoodSubgraphPairwiseDistance(r=3, d=2)\n",
    "\n",
    "print(\"Computing kernel matrics\\n\")\n",
    "t0 = time()\n",
    "K_train = gk.fit_transform(graphs_train)\n",
    "K_test = gk.transform(graphs_test)\n",
    "print(\"done in %0.3fs\\n\" % (time() - t0))\n",
    "\n",
    "print(\"Classifying digits\\n\")\n",
    "# Initialize SVM\n",
    "clf = SVC(kernel='precomputed')\n",
    "\n",
    "# Fit on the train Kernel\n",
    "clf.fit(K_train, y_train)\n",
    "\n",
    "# Predict and test.\n",
    "y_pred = clf.predict(K_test)\n",
    "\n",
    "# Calculate accuracy of classification.\n",
    "print(\"Classification accuracy: %0.2f\" % accuracy_score(y_test, y_pred))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
